CategoricalCrossentropy VS SparseCategoricalCrossentropy

Reason is the light and the light of life.

Jerry Su Dec 09, 2020 3 mins

SparseCategoricalCrossentropy

Use this crossentropy loss function when there are two or more label classes. SparseCategoricalCrossentropy expects labels to be provided as integers. If you want to provide labels in a one_hot representation, please use CategoricalCrossentropy loss instead. There should be # classes floating point values per feature for y_pred and a single value per feature for y_true.

https://www.tensorflow.org/api_docs/python/tf/keras/losses/SparseCategoricalCrossentropy

tf.keras.losses.Reduction

class TFTokenClassificationLoss:
    """
    Per-token sparse cross-entropy loss for token classification.

    .. note::
        Positions labelled -100 are dropped (together with their logits) before the loss is evaluated.
        If any label equals -1, a deprecation warning is emitted and -1 is used as the ignore value instead.
    """

    def compute_loss(self, labels, logits):
        """
        Return the unreduced cross-entropy over the positions that are not masked out.

        ``labels`` is flattened to one dimension and ``logits`` to two dimensions (the
        second sized by ``shape_list(logits)[2]``); masked positions are removed from
        both before the loss is called with ``from_logits=True`` and ``Reduction.NONE``.
        """
        criterion = tf.keras.losses.SparseCategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.NONE, from_logits=True
        )

        # Pick the sentinel that marks positions excluded from the loss:
        # the legacy -1 (with a warning) if it occurs anywhere, otherwise -100.
        if tf.math.reduce_any(labels == -1):
            warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
            ignore_index = -1
        else:
            ignore_index = -100

        # Boolean mask over the flattened labels: True where the position counts.
        keep = tf.reshape(labels, (-1,)) != ignore_index

        # Flatten logits to rows of class scores and keep only the unmasked rows.
        flat_logits = tf.reshape(logits, (-1, shape_list(logits)[2]))
        kept_logits = tf.boolean_mask(flat_logits, keep)
        kept_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), keep)

        return criterion(kept_labels, kept_logits)
import tensorflow as tf
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
labels = tf.constant([[[1, 2, 0, -100], [2, 1, -100, -100]]])
labels
<tf.Tensor: shape=(1, 2, 4), dtype=int32, numpy=
array([[[   1,    2,    0, -100],
        [   2,    1, -100, -100]]], dtype=int32)>
logits = tf.constant([[[2, 4, 5, 0], [1, 4, 0, 0]]])
logits
<tf.Tensor: shape=(1, 2, 4), dtype=int32, numpy=
array([[[2, 4, 5, 0],
        [1, 4, 0, 0]]], dtype=int32)>
tf.reshape(labels, (-1,))
<tf.Tensor: shape=(8,), dtype=int32, numpy=array([   1,    2,    0, -100,    2,    1, -100, -100], dtype=int32)>
active_loss = tf.reshape(labels, (-1,)) != -100
active_loss
<tf.Tensor: shape=(8,), dtype=bool, numpy=array([ True,  True,  True, False,  True,  True, False, False])>
# NOTE: flattening logits to 1-D here removes the class axis, so the masked result has
# shape (5,) -- the same rank as the masked labels -- and the loss call further down
# raises the "Shape mismatch" ValueError. TFTokenClassificationLoss.compute_loss
# reshapes to (-1, shape_list(logits)[2]) instead, keeping the last axis intact.
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1,)), active_loss)
reduced_logits
<tf.Tensor: shape=(5,), dtype=int32, numpy=array([2, 4, 5, 1, 4], dtype=int32)>
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
labels
<tf.Tensor: shape=(5,), dtype=int32, numpy=array([1, 2, 0, 2, 1], dtype=int32)>
loss_fn(y_true=labels, y_pred=reduced_logits)
---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

<ipython-input-9-de0f594293dd> in <module>
----> 1 loss_fn(y_true=labels, y_pred=reduced_logits)


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/keras/losses.py in __call__(self, y_true, y_pred, sample_weight)
    147     with K.name_scope(self._name_scope), graph_ctx:
    148       ag_call = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
--> 149       losses = ag_call(y_true, y_pred)
    150       return losses_utils.compute_weighted_loss(
    151           losses, sample_weight, reduction=self._get_reduction())


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    253       try:
    254         with conversion_ctx:
--> 255           return converted_call(f, args, kwargs, options=options)
    256       except Exception as e:  # pylint:disable=broad-except
    257         if hasattr(e, 'ag_error_metadata'):


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
    530 
    531   if not options.user_requested and conversion.is_whitelisted(f):
--> 532     return _call_unconverted(f, args, kwargs, options)
    533 
    534   # internal_convert_user_code is for example turned off when issuing a dynamic


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
    337 
    338   if kwargs is not None:
--> 339     return f(*args, **kwargs)
    340   return f(*args)
    341


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/keras/losses.py in call(self, y_true, y_pred)
    251           y_pred, y_true)
    252     ag_fn = autograph.tf_convert(self.fn, ag_ctx.control_status_ctx())
--> 253     return ag_fn(y_true, y_pred, **self._fn_kwargs)
    254 
    255   def get_config(self):


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    199     """Call target, and fall back on dispatchers if there is a TypeError."""
    200     try:
--> 201       return target(*args, **kwargs)
    202     except (TypeError, ValueError):
    203       # Note: convert_to_eager_tensor currently raises a ValueError, not a


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/keras/losses.py in sparse_categorical_crossentropy(y_true, y_pred, from_logits, axis)
   1564   y_pred = ops.convert_to_tensor_v2(y_pred)
   1565   y_true = math_ops.cast(y_true, y_pred.dtype)
-> 1566   return K.sparse_categorical_crossentropy(
   1567       y_true, y_pred, from_logits=from_logits, axis=axis)
   1568


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    199     """Call target, and fall back on dispatchers if there is a TypeError."""
    200     try:
--> 201       return target(*args, **kwargs)
    202     except (TypeError, ValueError):
    203       # Note: convert_to_eager_tensor currently raises a ValueError, not a


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/keras/backend.py in sparse_categorical_crossentropy(target, output, from_logits, axis)
   4783           labels=target, logits=output)
   4784   else:
-> 4785     res = nn.sparse_softmax_cross_entropy_with_logits_v2(
   4786         labels=target, logits=output)
   4787


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    199     """Call target, and fall back on dispatchers if there is a TypeError."""
    200     try:
--> 201       return target(*args, **kwargs)
    202     except (TypeError, ValueError):
    203       # Note: convert_to_eager_tensor currently raises a ValueError, not a


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/ops/nn_ops.py in sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name)
   4173       of the labels is not equal to the rank of the logits minus one.
   4174   """
-> 4175   return sparse_softmax_cross_entropy_with_logits(
   4176       labels=labels, logits=logits, name=name)
   4177


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    199     """Call target, and fall back on dispatchers if there is a TypeError."""
    200     try:
--> 201       return target(*args, **kwargs)
    202     except (TypeError, ValueError):
    203       # Note: convert_to_eager_tensor currently raises a ValueError, not a


/opt/conda/envs/blog/lib/python3.8/site-packages/tensorflow/python/ops/nn_ops.py in sparse_softmax_cross_entropy_with_logits(_sentinel, labels, logits, name)
   4086     if (static_shapes_fully_defined and
   4087         labels_static_shape != logits.get_shape()[:-1]):
-> 4088       raise ValueError("Shape mismatch: The shape of labels (received %s) "
   4089                        "should equal the shape of logits except for the last "
   4090                        "dimension (received %s)." % (labels_static_shape,


ValueError: Shape mismatch: The shape of labels (received (5,)) should equal the shape of logits except for the last dimension (received (1, 5)).


Read more:

Related posts: